{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !sync; echo 3 > /proc/sys/vm/drop_caches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from time import time\n",
    "import re\n",
    "import glob\n",
    "import warnings\n",
    "\n",
    "# tools for data preproc/loading\n",
    "import torch\n",
    "# torch.multiprocessing.set_start_method(\"spawn\")\n",
    "# import rmm\n",
    "# import nvtabular as nvt\n",
    "# from nvtabular.ops import Normalize,  Categorify,  LogOp, FillMissing, Clip, get_embedding_sizes\n",
    "# from nvtabular.loader.torch import TorchAsyncItr, DLDataLoader\n",
    "# from nvtabular.utils import device_mem_size\n",
    "\n",
    "# tools for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define some information about where to get our data\n",
    "INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/raid/criteo/tests/crit_int_pq')\n",
    "OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/raid/criteo/tests/test_dask') # where we'll save our procesed data to\n",
    "BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32768))\n",
    "AMP = os.environ.get(\"AMP\", \"true\") \n",
    "AMP = True if AMP.lower() in \"true\" else False\n",
    "SHUFFLE = os.environ.get(\"SHUFFLE\", False)\n",
    "PARTS_PER_CHUNK = int(os.environ.get('PARTS_PER_CHUNK', 2))\n",
    "NUM_TRAIN_DAYS = 23 # number of days worth of data to use for training, the rest will be used for validation\n",
    "\n",
    "# define our dataset schema\n",
    "CONTINUOUS_COLUMNS = ['I' + str(x) for x in range(1,14)]\n",
    "CATEGORICAL_COLUMNS =  ['C' + str(x) for x in range(1,27)]\n",
    "LABEL_COLUMNS = ['label']\n",
    "COLUMNS = CONTINUOUS_COLUMNS + CATEGORICAL_COLUMNS + LABEL_COLUMNS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_train_dir = os.path.join(OUTPUT_DATA_DIR, 'train/')\n",
    "output_valid_dir = os.path.join(OUTPUT_DATA_DIR, 'valid/')\n",
    "! mkdir -p $output_train_dir\n",
    "! mkdir -p $output_valid_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_paths = glob.glob(os.path.join(output_train_dir, \"*.parquet\"))\n",
    "valid_paths = glob.glob(os.path.join(output_valid_dir, \"*.parquet\"))\n",
    "\n",
    "from pyarrow.parquet import read_metadata\n",
    "train_stats = read_metadata(train_paths[0])\n",
    "valid_stats = read_metadata(valid_paths[0])\n",
    "print(train_stats, valid_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_transform(batch, cont_cols=CONTINUOUS_COLUMNS, cat_cols=CATEGORICAL_COLUMNS, label_cols=LABEL_COLUMNS):\n",
    "    x_cat = torch.tensor(batch[sorted(cat_cols)].values,dtype=torch.int64)\n",
    "    x_cont = torch.tensor(batch[cont_cols].values, dtype=torch.float32)\n",
    "    y = torch.tensor(batch[label_cols[0]], dtype=torch.float32)\n",
    "    return x_cat, x_cont, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "# Create Dataset\n",
    "class GenericDataset(Dataset):\n",
    "    def __init__(self, tar_file, transform=None):\n",
    "        self.frame = pd.read_parquet(tar_file)[:10000000]\n",
    "        self.transform = transform\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.frame.shape[0]\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        if torch.is_tensor(idx):\n",
    "            idx = idx.tolist()\n",
    "        rec = self.frame.iloc[idx]\n",
    "        if self.transform:\n",
    "            rec = self.transform(rec)\n",
    "        re\n",
    "        return rec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(GenericDataset(train_paths[0], transform=batch_transform), \n",
    "                          pin_memory=True,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=20,\n",
    "#                           collate_fn=collate,\n",
    "                          shuffle=SHUFFLE,\n",
    "                         )\n",
    "valid_loader = DataLoader(GenericDataset(valid_paths[0], transform=batch_transform), \n",
    "                          pin_memory=True, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=20,\n",
    "#                           collate_fn=collate,\n",
    "                          shuffle=SHUFFLE,\n",
    "                         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_DROPOUT_RATE = 0.04\n",
    "DROPOUT_RATES = [0.001, 0.01]\n",
    "HIDDEN_DIMS = [1000, 500]\n",
    "LEARNING_RATE = 0.001\n",
    "EPOCHS = 1\n",
    "embeddings = {'C1': (7599500, 16),\n",
    " 'C10': (5345303, 16),\n",
    " 'C11': (561810, 16),\n",
    " 'C12': (242827, 16),\n",
    " 'C13': (11, 6),\n",
    " 'C14': (2209, 16),\n",
    " 'C15': (10616, 16),\n",
    " 'C16': (100, 16),\n",
    " 'C17': (4, 3),\n",
    " 'C18': (968, 16),\n",
    " 'C19': (15, 7),\n",
    " 'C2': (33521, 16),\n",
    " 'C20': (7838519, 16),\n",
    " 'C21': (2580502, 16),\n",
    " 'C22': (6878028, 16),\n",
    " 'C23': (298771, 16),\n",
    " 'C24': (11951, 16),\n",
    " 'C25': (97, 16),\n",
    " 'C26': (35, 12),\n",
    " 'C3': (17022, 16),\n",
    " 'C4': (7339, 16),\n",
    " 'C5': (20046, 16),\n",
    " 'C6': (4, 3),\n",
    " 'C7': (7068, 16),\n",
    " 'C8': (1377, 16),\n",
    " 'C9': (63, 16)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvtabular.framework_utils.torch.models import Model\n",
    "from nvtabular.framework_utils.torch.utils import process_epoch\n",
    "model = Model(\n",
    "    embedding_table_shapes=embeddings,\n",
    "    num_continuous=len(CONTINUOUS_COLUMNS),\n",
    "    emb_dropout=EMBEDDING_DROPOUT_RATE,\n",
    "    layer_hidden_dims=HIDDEN_DIMS,\n",
    "    layer_dropout_rates=DROPOUT_RATES,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "def rmspe_func(y_pred, y):\n",
    "    \"Return y_pred and y to non-log space and compute RMSPE\"\n",
    "    y_pred, y = torch.exp(y_pred) - 1, torch.exp(y) - 1\n",
    "    pct_var = (y_pred - y) / y\n",
    "    return (pct_var**2).mean().pow(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(EPOCHS):\n",
    "    start_train=time()\n",
    "    train_loss, y_pred, y = process_epoch(train_loader, \n",
    "                                          model, \n",
    "                                          train=True, \n",
    "                                          optimizer=optimizer,\n",
    "                                          #transform=batch_transform,\n",
    "                                          amp=AMP,\n",
    "                                          #device=device,\n",
    "                                         )\n",
    "    train_rmspe = rmspe_func(y_pred, y)\n",
    "    train_time=time() - start_train\n",
    "    y_pred = None\n",
    "    y = None\n",
    "    print(train_time)\n",
    "    start_valid=time()\n",
    "    valid_loss, y_pred, y = process_epoch(valid_loader,\n",
    "                                          model, \n",
    "                                          train=False, \n",
    "                                          optimizer=optimizer,\n",
    "                                          #transform=batch_transform,\n",
    "                                          amp=AMP,\n",
    "                                          #device=device,\n",
    "                                         )\n",
    "    valid_rmspe = rmspe_func(y_pred, y)\n",
    "    valid_time = time() - start_valid\n",
    "    y_pred = None\n",
    "    y = None\n",
    "    print(f\"Train:{train_time} + Valid:{valid_time} = EpochTotal:{train_time + valid_time}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
